package db

import (
	"fmt"
	"github.com/hanc00l/nemo_go/pkg/utils"
	"gorm.io/gorm"
	"time"
)

// Vulnerability is the GORM model for one row of the "vulnerability" table
// (see TableName). Each field maps to the column named in its gorm tag; Id is
// the primary key. Hash is set by Add to the MD5 of Target, Url, PocFile and
// Source concatenated, and is the lookup key used by GetByVulnerability.
// CreateDatetime is stamped by Add; UpdateDatetime is stamped by Add and Update.
type Vulnerability struct {
	Id             int       `gorm:"primaryKey"`
	Target         string    `gorm:"column:target"`
	Url            string    `gorm:"column:url"`
	PocFile        string    `gorm:"column:poc_file"`
	Source         string    `gorm:"column:source"`
	Extra          string    `gorm:"column:extra"`
	Hash           string    `gorm:"column:hash"`
	WorkspaceId    int       `gorm:"column:workspace_id"`
	CreateDatetime time.Time `gorm:"column:create_datetime"`
	UpdateDatetime time.Time `gorm:"column:update_datetime"`
}

// TableName tells GORM which table backs the Vulnerability model.
func (*Vulnerability) TableName() string {
	const tableName = "vulnerability"
	return tableName
}

// Add inserts vul as a new record and reports whether the insert succeeded.
//
// Before inserting it stamps CreateDatetime and UpdateDatetime with the same
// current time and computes Hash from Target, Url, PocFile and Source (the
// key GetByVulnerability looks up). The generated primary key is not
// returned; GORM's Create back-fills it into vul.Id.
func (vul *Vulnerability) Add() (success bool) {
	// One timestamp, so a freshly inserted record has identical create and
	// update times.
	now := time.Now()
	vul.CreateDatetime = now
	vul.UpdateDatetime = now
	// NOTE(review): fields are concatenated without a separator, so distinct
	// field splits can collide (Target "ab"+Url "c" vs Target "a"+Url "bc").
	// Kept unchanged because already-stored hashes depend on this exact format.
	vul.Hash = utils.MD5(fmt.Sprintf("%s%s%s%s", vul.Target, vul.Url, vul.PocFile, vul.Source))

	db := GetDB()
	defer CloseDB(db)
	result := db.Create(vul)
	// A non-nil Error means the insert failed or was rolled back, even if
	// RowsAffected was already incremented.
	return result.Error == nil && result.RowsAffected > 0
}

// Get loads the record whose primary key equals vul.Id into vul and reports
// whether such a record was found.
func (vul *Vulnerability) Get() (success bool) {
	conn := GetDB()
	defer CloseDB(conn)

	found := conn.First(vul, vul.Id).RowsAffected > 0
	return found
}

// GetByVulnerability finds the record whose hash equals the MD5 of vul's
// Target, Url, PocFile and Source concatenated (the same key Add stores),
// loads it into vul, and reports whether one was found.
func (vul *Vulnerability) GetByVulnerability() (success bool) {
	key := fmt.Sprintf("%s%s%s%s", vul.Target, vul.Url, vul.PocFile, vul.Source)
	digest := utils.MD5(key)

	conn := GetDB()
	defer CloseDB(conn)
	return conn.Where("hash = ?", digest).First(vul).RowsAffected > 0
}

// GetsByTarget returns every record whose target equals vul.Target, ordered
// by update_datetime with the most recently updated first.
func (vul *Vulnerability) GetsByTarget() (results []Vulnerability) {
	conn := GetDB()
	defer CloseDB(conn)

	conn.Where("target = ?", vul.Target).Order("update_datetime desc").Find(&results)
	return results
}

// Update writes the column→value pairs in updateMap to the record whose
// primary key is vul.Id. It also sets update_datetime to the current time.
// It reports whether any row was changed.
//
// The caller's map is not modified. A nil map is accepted, in which case only
// the timestamp is written.
func (vul *Vulnerability) Update(updateMap map[string]interface{}) (success bool) {
	// Build a fresh map: writing the timestamp straight into updateMap would
	// panic on a nil map and leak the extra key back to the caller.
	values := make(map[string]interface{}, len(updateMap)+1)
	for column, value := range updateMap {
		values[column] = value
	}
	values["update_datetime"] = time.Now()

	db := GetDB()
	defer CloseDB(db)
	result := db.Model(vul).Updates(values)
	return result.Error == nil && result.RowsAffected > 0
}

// Delete removes the record whose primary key equals vul.Id and reports
// whether a row was actually deleted.
func (vul *Vulnerability) Delete() (success bool) {
	conn := GetDB()
	defer CloseDB(conn)

	affected := conn.Delete(vul, vul.Id).RowsAffected
	return affected > 0
}

// Count returns the number of records matching the conditions in searchMap.
// The supported keys are those handled by makeWhere.
func (vul *Vulnerability) Count(searchMap map[string]interface{}) (count int) {
	var total int64
	query := vul.makeWhere(searchMap).Model(vul)
	defer CloseDB(query)

	query.Count(&total)
	count = int(total)
	return count
}

// makeWhere starts a query on a connection from GetDB and adds one condition
// per entry of searchMap. It is shared by Count and Gets, which release the
// returned handle with CloseDB.
//
// Keys are handled as follows:
//   - "target", "poc_file": passed to makeLike for that column
//   - "date_delta": the value must be an int (any other type panics on the
//     type assertion) and is passed to makeDateDelta on update_datetime
//   - any other key: passed unchanged as db.Where(key, value)
func (vul *Vulnerability) makeWhere(searchMap map[string]interface{}) *gorm.DB {
	query := GetDB()
	for key, val := range searchMap {
		switch key {
		case "target", "poc_file":
			query = makeLike(val, key, query)
		case "date_delta":
			query = makeDateDelta(val.(int), "update_datetime", query)
		default:
			query = query.Where(key, val)
		}
	}
	return query
}

// Gets returns the records matching searchMap, newest update_datetime first.
// It also returns the total number of matching records, which ignores paging.
// Paging applies only when both page (1-based) and rowsPerPage are positive;
// otherwise every match is returned.
func (vul *Vulnerability) Gets(searchMap map[string]interface{}, page, rowsPerPage int) (results []Vulnerability, count int) {
	query := vul.makeWhere(searchMap).Model(vul)
	defer CloseDB(query)

	// Count before Offset/Limit are attached so the total covers every match.
	var total int64
	query.Count(&total)

	paged := rowsPerPage > 0 && page > 0
	if paged {
		offset := (page - 1) * rowsPerPage
		query = query.Offset(offset).Limit(rowsPerPage)
	}
	query.Order("update_datetime desc").Find(&results)

	return results, int(total)
}

// SaveOrUpdate stores vul, deduplicating on the Target/Url/PocFile/Source hash.
//
// If a matching record already exists, vul.Id is set to that record's id.
// That record then gets update_datetime refreshed, plus extra when vul.Extra
// is non-empty. Otherwise vul is inserted via Add.
//
// It returns whether the write succeeded and whether it was an insert.
func (vul *Vulnerability) SaveOrUpdate() (success bool, isAdd bool) {
	// Only the hashed fields matter for the lookup; the rest of the probe is
	// overwritten by the loaded row if one is found.
	existing := &Vulnerability{
		Target:  vul.Target,
		Url:     vul.Url,
		PocFile: vul.PocFile,
		Source:  vul.Source,
	}
	if !existing.GetByVulnerability() {
		return vul.Add(), true
	}

	changes := make(map[string]interface{})
	if len(vul.Extra) > 0 {
		changes["extra"] = vul.Extra
	}
	vul.Id = existing.Id
	return vul.Update(changes), false
}
